{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## **演示0503：梯度下降算法**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### **计算过程：**\n",
    "假设下列函数定义：\n",
    "$ F(w) = \\frac{1}{2m} \\sum_{i=1}^{m}[f_w(x^{(i)}) - y^{(i)}]^2 = \\frac{1}{2m} \\sum_{i=1}^{m}(w_0+w_1 x^{(i)} - y^{(i)})^2 $\n",
    "其中：\n",
    "* $m$:样本数据点数量\n",
    "* $f_w$ :以$w$为系数的拟合函数\n",
    "* $x^{(i)}$:第$i$个数据的横坐标值(数据中的自变量)\n",
    "* $y^{(i)}$:第$i$个数据的纵坐标值(数据中的因变量)\n",
    "* $f_w(x^{(i)})$:将$x^{(i)}$ 代入到拟合函数计算的结果\n",
    "* $F(w)$:也称为**目标函数**或**成本函数**  \n",
    "根据最小二乘法可知,针对函数$F(w)$,要设法求出最优的$w$,从而使得$F(w)$的值最小。求解过程可以使用梯度下降法来实现。\n",
    "\n",
    "对于只有一个自变量($x$)的情形，此时最优的$w$实际上包括$w_0$和$w_1$，使用梯度下降算法的基本步骤如下：\n",
    "1. 给定目标函数$F(w)$(实际上也是$F(w_0, w_1)$)和学习速率(learning rate): $ \\alpha $\n",
    "2. 随机初始化$w_0$和$w_1$ ,计算出$F(w_0 , w_1)$\n",
    "3. 计算出能够使$F(w)$下降最快的$w$偏移量(梯度)：$\\Delta w$，叠加到初始的$w_0$和$w_1$上\n",
    "  * 针对$w_0$和$w_1$，分别计算梯度$\\nabla w_0$和$\\nabla w_1$\n",
    "  * 按下列公式调整$w_0$和$w_1$:\n",
    "    * $ w_0 = w_0 - \\alpha \\nabla w_0 $\n",
    "    * $ w_1 = w_1 - \\alpha \\nabla w_1 $\n",
    "    * 请注意，$w_0$和$w_1$的变化方向要与导数的方向相反\n",
    "  * 利用调整后的$w_0$和$w_1$，再次计算$F(w_0, w_1)$\n",
    "4. 循环第3步过程，直到最近两次计算出来的目标函数的差值小于某个误差限\n",
    "5. 此时$w_0$和$w_1$就是目标函数的最优解  \n",
    "\n",
    "注意：\n",
    "* 可能需要尝试及合理选择$\\alpha $。$\\alpha $太大，有可能越过了最优点；$\\alpha $小，收敛速度会非常慢\n",
    "* 当前主流的梯度算法，能够随着梯度下降循环的进行，动态调整$\\alpha $值，从而能够较快的收敛，又能避免跨越最优点的情形\n",
    "* 找到的最优点，也有可能是局部最优点而非全局最优点"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### **算法原理示意**\n",
    "下图展示了梯度下降算法的工作原理：\n",
    "\n",
    "![](../images/050301.png)\n",
    "\n",
    "* 选择一个初始点(即给定初始的$w$)，作为当前工作点\n",
    "* 计算该点出的梯度值(导数$\\dfrac{\\partial F}{\\partial w}$)，沿着梯度最大的方向，移动当前工作点到新的位置\n",
    "* 经过若干次移动后，如果新的工作点位置对应的$F(w)$值与上一个位置几乎没有差异，则可以认为找到了极值点\n",
    "* 请注意，一般来说，采用梯度下降法求取极小值（而不是极大值）"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### **批量梯度下降法**\n",
    "> 梯度下降法中最重要的一步是计算某个工作点的梯度$\\nabla w$。有多种用于计算$\\nabla w$的方法,其计算结果也会有所不同\n",
    "> * 批量梯度下降法是指：在计算$\\nabla w$时,要把所有样本点数据的梯度都计算进去,也就是直接针对$F(w)$来计算偏导(注意$F(w)$中包含了所有样本点的数据值)：\n",
    ">\n",
    ">  $ \\begin{aligned} \\\\ \n",
    "\\nabla w_0 & =\\frac{\\partial F(w)}{\\partial w_0} \\\\ \\\\ & =\\dfrac{\\partial \\frac{1}{2m} \\sum_{i=1}^{n}(w_0+w_1 x^{(i)}-y^{(i)})^2}{\\partial w_0} \\\\ \\\\\n",
    "& = \\dfrac{\\frac{1}{2m} \\cdot 2 \\cdot \\sum_{i=1}^m [(w_0+w_1 x^{(i)}-y^{(i)}) \\cdot \\partial {(w_0+w_1 x^{(i)}-y^{(i)})}]}{\\partial w_0} \\\\ \\\\\n",
    "& = \\dfrac{1}{m} \\sum_{i=1}^m [(w_0+w_1 x^{(i)}-y^{(i)}) \\cdot (1+0+0)] \\\\ \\\\ & =\\dfrac{1}{m} \\sum_{i=1}^m (w_0+w_1 x^{(i)}-y^{(i)})\n",
    "\\end{aligned} $\n",
    ">\n",
    ">  $ \\begin{aligned}\n",
    "\\nabla w_1 & =\\frac{\\partial F(w)}{\\partial w_1} \\\\ \\\\ & =\\dfrac{\\partial \\frac{1}{2m} \\sum_{i=1}^{n}(w_0+w_1 x^{(i)}-y^{(i)})^2}{\\partial w_1} \\\\ \\\\\n",
    "& = \\dfrac{\\frac{1}{2m} \\cdot 2 \\cdot \\sum_{i=1}^m [(w_0+w_1 x^{(i)}-y^{(i)}) \\cdot \\partial {(w_0+w_1 x^{(i)}-y^{(i)})}]}{\\partial w_1} \\\\ \\\\\n",
    "& = \\dfrac{1}{m} \\sum_{i=1}^m [(w_0+w_1 x^{(i)}-y^{(i)}) \\cdot (0 + x^{(i)} + 0)] \\\\ \\\\ & =\\frac{1}{m} \\sum_{i=1}^m [(w_0+w_1 x^{(i)}-y^{(i)}) \\cdot x^{(i)}] \n",
    "\\end{aligned} $\n",
    ">\n",
    "> * 在梯度下降算法的每一步，计算出$\\nabla w$后，代入下式中进行下一步计算：\n",
    ">\n",
    ">  $ w_0 = w_0-\\alpha \\nabla w_0 $\n",
    ">\n",
    ">  $ w_1 = w_1-\\alpha \\nabla w_1 $"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### **案例1：使用批量梯度下降算法拟合直线**\n",
    "> 待拟合的二维平面数据点：(6, 7), (8, 9), (10, 13), (14, 17.5), (18, 18)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">**步骤1：定义目标函数、梯度函数**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "# 定义目标函数，该目标函数实际上是残差平方和。相当于采用了最小二乘法。\n",
    "def target_function(W, X, Y):                           \n",
    "    w1, w0 = W\n",
    "    return np.sum((w0 + X * w1 - Y) ** 2) / (2 * len(X))  # 应使用np.sum，而不要使用sum。np.sum支持向量/矩阵运算\n",
    "\n",
    "# 定义梯度\n",
    "# 根据目标函数，分别对x0和x1求导数(梯度)，并累计各点的导数之平均值\n",
    "def gradient_function(W, X, Y):                         \n",
    "    w1, w0 = W\n",
    "    w0_grad = np.sum(w0 + X * w1 - Y) / len(X)      # 对应w0的导数\n",
    "    w1_grad = X.dot(w0 + X * w1 - Y) / len(X)       # 对应w1的导数。注意采用向量运算\n",
    "    return np.array([w1_grad, w0_grad]) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> **步骤2：定义批量梯度下降算法计算函数**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义最简单的批量梯度下降算法函数\n",
    "def batch_gradient_descent(target_fn, gradient_fn, init_W, X, Y, learning_rate=0.001, tolerance=1e-12):\n",
    "    W = init_W\n",
    "    target_value = target_fn(W, X, Y)                   # 计算当前w下的F(w)值\n",
    "    while True:\n",
    "        gradient = gradient_fn(W, X, Y)                 # 计算梯度\n",
    "        next_W = W - gradient * learning_rate           # 调整W。向量计算，同时调整了w1和w0\n",
    "        next_target_value = target_fn(next_W, X, Y)     # 计算新theta下的cost值\n",
    "        if abs(next_target_value - target_value) < tolerance:   # 如果两次计算之间的误差小于tolerance，则表明已经收敛\n",
    "            return next_W\n",
    "        else:\n",
    "            W, target_value = next_W, next_target_value         # 继续进行下一轮计算"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">**步骤3：执行批量梯度下降操作，获取最优权重参数解**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "最优的w1和w0: 0.9763122679381336 1.9652710600118533\n"
     ]
    }
   ],
   "source": [
    "# 待拟合的数据点\n",
    "x = np.array([6, 8, 10, 14, 18])\n",
    "y = np.array([7, 9, 13, 17.5, 18])\n",
    "\n",
    "np.random.seed(0)\n",
    "init_W = np.array([np.random.random(), np.random.random()])      # 随机初始化W值\n",
    "w1, w0 = batch_gradient_descent(target_function, gradient_function, init_W, x, y) \n",
    "print(\"最优的w1和w0:\", w1, w0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">**步骤4：根据最优解做出拟合曲线，并作图比较**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEKCAYAAAD9xUlFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAGjRJREFUeJzt3Xt0VPW99/HPl0sEAQW5WFAU6iVA0QfBKjkoDXIQFbxUpUdwWTk8TaJQr+hRpAfReqE1iDcuQrXQp9AWRAuJCEoggDUUDVJAEEVBUVIoCppQSCD5Pn9k3IkIQiAzeyZ5v9ZiZWZPZvbH39rwcV/mt83dBQCAJNUJOwAAIH5QCgCAAKUAAAhQCgCAAKUAAAhQCgCAQNRKwczamtliM1tvZu+Z2R2R5aPN7HMzWxX5c0W0MgAAqsai9T0FM2stqbW7rzSzJpLyJV0j6WeSitw9MyorBgActXrR+mB3L5BUEHlcaGbrJZ0SrfUBAI5d1PYUvrUSs3aSlkrqLOluSYMlfS3pHUnD3X3nQd6TLildkho0aNDttNNOi3rORFBWVqY6dTgVJDEWlTEWFRiLCh988MEOd29ZlfdEvRTMrLGkJZIedfeXzexkSTskuaRfq/wQ05Dv+4zk5GTfsGFDVHMmitzcXKWmpoYdIy4wFhUYiwqMRQUzy3f386vynqjWqZnVlzRb0nR3f1mS3H2bu5e6e5mkKZIuiGYGAMCRi+bVRybpBUnr3f3JSstbV/q1n0paG60MAICqidqJZkk9JN0kaY2ZrYose0DSQDProvLDR5slZUQxAwCgCqJ59dGbkuwgL82L1joBAMeGU/QAgAClAAAIUAoAgAClAAAIUAoAgAClAAAIUAoAgAClAAAIUAoAgAClAAAIUAoAgAClAAAIUAoAgAClAAAIUAoAgAClAAAIUAoAgAClAAAIUAoAgAClAAAIUAoAgAClAAAIUAoAgAClAAAIUAoAgAClAAAIUAoAgAClAAAIUAoAgAClAAAIUAoAgAClAAAIUAoAgAClAAAIUAoAgEDUSsHM2prZYjNbb2bvmdkdkeUnmdkbZvZh5GezaGUAAFRNNPcU9ksa7u4dJXWXNMzMOkm6X1KOu58lKSfyHAAQB6JWCu5e4O4rI48LJa2XdIqkqyVNi/zaNEnXRCsDAKBqzN2jvxKzdpKWSuos6VN3b1rptZ3u/p1DSGaWLildklq2bNlt5syZUc+ZCIqKitS4ceOwY8QFxqICY1GBsajQq1evfHc/vyrviXopmFljSUskPeruL5vZriMphcqSk5N9w4YNUc2ZKHJzc5Wamhp2jLjAWFRgLCowFhXMrMqlENWrj8ysvqTZkqa7+8uRxdvMrHXk9daStkczA4Bwubv27NsTdgwcoWhefWSSXpC03t2frPTSXEk3Rx7fLGlOtDIACNfmXZvVb0Y//fyvPw87Co5QNPcUeki6SdIlZrYq8ucKSWMk9TGzDyX1iTwHUIPsL9uvcXnj9KMJP9LST5bqorYXKRbnL3Hs6kXrg939TUl2iJd7R2u9AML1bsG7SstKU35Bvvqd1U8T+k3QaSeeFnYsHKGolQKA2mV3yW6Nzh2tccvHqcXxLfSX6/+iAZ0GqPxIMhIFpQDgmL3+0eu6JfsWbdq1SWld0/Sb//yNmjVksoJERCkAOGr/2v0v3f363frj6j8quXmylgxeop6n9ww7Fo4BpQCgytxdf/jHH3T363ersLhQo3qO0oiLR6hBvQZhR8MxohQAVMnGLzcqIztDizYtUo+2PTT5ysnq1LJT2LFQTSgFAEdkX+k+jc0bq4eWPKSkukma2G+i0rulq44xA39NQikAOKwVn69QWlaaVm9brWs7XqtnL39WbZq0CTsWooBSAHBIhcWF+tWiX+nZFeUl8Mp/vaJrOjCxcU1GKQA4qKwNWRo6b6g+//pzDfvxMD3a+1GdcNwJYcdClFEKAL6loLBAd8y/Q7PWzdKPWv5IM4fMVErblKP+vLy8vGDm0pSUo/8cxAalAECSVOZlemHlC7r3jXu1d/9ePdLrEd3b414l1U066s/My8tT7969VVJSoqSkJOXk5FAMcY5SAKD3d7yv9Kx0Lft0mVLbper5/s/r7OZnH/Pn5ubmqqSkRKWlpSopKVFubi6lEOcoBaAWK95frDFvjtFjbz6mRvUb6cWrXtTgLoOrbb6i1NRUJSUlBXsK3Pwm/lEKQC315qdvKj0rXet3rNfAzgM1ru84ndz45GpdR0pKinJycjinkEAoBSBOxOqE7K69uzRi4QhNyp+k0088XfMGzdPlZ10etfWlpKRQBgmEUgDiQCxOyLq7Xl7/sm577TZt271Nd3e/Ww/1ekiNk7jJPSrw/XQgDhzshGx12vLVFl3zl2t0/azr9YPGP9CKX6zQ2L5jKQR8B3sKQByI1gnZ0rJSTXxnokbkjFBpWame6POE7ux+p+rV4a8+Do4tA4gD0Tghu2bbGqVnp2v5Z8t16RmXalK/SWrfrH01pEVNRikAcaK6Tsju2bdHjyx9RL9967dq2qCp/vjTP2rQOYO4LSaOCKUA1CCLNi1SRnaGNn65UYO7DFZmn0w1P7552LGQQCgFoAb4at9XGjJniH6/6vc6o9kZWnjTQvX+Ye+wYyEBUQpAAnN3/XntnzX07aEq3F+o+3vcr1E/GaWG9RuGHQ0JilIAEtTmXZt166u3av7G+erQpIOW3LhE5558btixkOAoBSDB7C/br6eXP61RuaNUx+romcueUad/d6IQUC0oBSCBrCxYqbSsNK0sWKn+Z/fXhCsmqO2Jbav9y26ovSgFIAHsLtmt0bmjNW75OLU4voVmXj9T13e6nstMUe0oBSDOLdi4QLe8eos279qs9K7pGvOfY9SsYbOwY6GGohSAOLV993bdteAuzVgzQx1adNDSwUt18ekXhx0LNRylAMQZd9e0f0zT8NeHq7C4UA/+5EGNuGiEjqt3XNjRUAtQCkAc2fjlRmVkZ2jRpkXq0baHJl85WZ1adgo7FmoRSgGIA/tK9ynzrUw9vPRhJdVN0qR+k5TWLU11jNntEVuUAhCyv3/2d6VlpWnN9jW6ruN1eubyZ9SmSZuwY6GWohSAkBQWF2rkopF6bsVzatOkjebcMEdXJV8VdizUclHbNzWzF81su5mtrbRstJl9bmarIn+uiNb6gXiWtSFLnSZ00nMrntOwHw/TumHrKATEhWjuKUyV9JykPxywfJy7Z0ZxvUDcKigs0O3zb9dL615S51adNWvALHU/tXvYsYBA1ErB3ZeaWbtofT6QSMq8TFPyp+i+hfdp7/69evSSR3Xvf9yr+nXrhx0N+BZz9+h9eHkpZLt758jz0ZIGS/pa0juShrv7zkO8N11SuiS1bNmy28yZM6OWM5EUFRWpcWNuti4lzlh8svsTjf1grNZ8vUbnNT1Pd591t049/tRqXUeijEUsMBYVevXqle/u51flPbEuhZMl7ZDkkn4tqbW7Dznc5yQnJ/uGDRuiljORfHMPX8T/WBTvL9aYN8fosTcfU6P6jTT20rEa3GVwVOYrivexiCXGooKZVbkUYnr1kbtv++axmU2RlB3L9QOxsuyTZUrPTtf7O97XoHMGaVzfcWrVqFXYsYDDiuk3Y8ysdaWnP5W09lC/CySiXXt3KSMrQz2n9tSefXs0b9A8Tb92OoWAhBG1PQUz+5OkVEktzOwzSQ9KSjWzLio/fLRZUka01g/Ekrtr9vrZuu2127R993YNTxmuh1IfUqOkRmFHA6okmlcfDTzI4heitT4gLFu+2qJfvvZLzd0wV+f94DxlD8xWtzbdwo4FHBW+0QwcpdKyUk14e4IeWPSASstKldknU3d0v0P16vDXComLrRc4Cqu3rVZaVppWfL5Cfc/oq4n9Jqp9s/ZhxwKOGaUAVMGefXv066W/1hNvPaFmDZpp+rXTNbDzQG6LiRqDUgCO0KJNi5SRnaGNX27U4C6DldknU82Pbx52LKBaUQrAYXzx7y90zxv3aOqqqTqj2RlaeNNC9f5h77BjAVFBKQCH4O6asWaG7lxwp3bt3aURF43Q//b8XzWs3zDsaEDUUArAQWzauUm3vnqrFny0QBeccoGmXDlF5558btixgKijFIBK9pft19PLn9ao3FGqY3X0zGXPaOiPh6punbphRwNiglIAIvK35istK03v/vNdXXn2lRp/xXi1PbFt2LGAmKIUUOvtLtmtUYtH6am/P6VWjVpp1oBZuq7jdVxmilqJUkBCysvL0/Tp03XccccpJSXlqD9n/sb5uiX7Fn3y1SdK75qu3/T5jZo2aFqNSYHEQikg4eTl5al3794qLi7W9OnTlZOTU+Vi2L57u+5acJdmrJmhDi06aOngpbr49IujlBhIHDGdOhuoDrm5uSopKVFZWZlKSkqUm5t7xO91d/3+3d+rw3MdNOu9WXrwJw9qVcYqCgGIYE8BCSc1NVVJSUkqLi5WUlLSEd9l68MvPlRGdoYWb16sHm17aMqVU9SxZcfohgUSzCH3FMxsXuR2mkBcSUlJUU5OjoYMGXJEh472le7TY8se0zkTz1F+Qb4m9Zukpf+9lEIADuL79hSmSnrdzKZJ+q2774tNJODwUlJSVFxcfNhCWP7ZcqVlpWnt9rW6vtP1evqyp9WmSZsYpQQSzyFLwd1nmtmrkkZJesfM/p+kskqvPxmDfMBR+br4a43MGanxb49XmyZtNOeGOboq+aqwYwFx73DnFPZJ2i3pOElNVKkUgHg1d8NcDX11qLYWbtWwHw/To70f1QnHnRB2LCAhHLIUzOwySU9Kmiupq7v/O2apgKNQUFig2167TbPXz1bnVp310s9eUvdTu4cdC0go37enMFLSAHd/L1ZhgKNR5mWakj9F9y28T3v379Vjlzyme/7jHtWvWz/saEDC+b5zCly4jbi37l/rlJ6Vrr9t+Zt6teul5/s/r7OanxV2LCBh8T0FJKTi/cWaunmqZiybocZJjfXiVS9qcJfBzFcEHCNKAQln2SfLlJ6drvd3vK9B5wzSuL7j1KpRq7BjATUC01wgYezau0sZWRnqObWn9uzbozGdx2j6tdMpBKAaUQqIe+6uWe/NUsfxHfW7d3+n4SnD9d7Q93Rh8wvDjgbUOBw+Qlzb8tUWDZ03VNkfZOu8H5yn7IHZ6tamW9ixgBqLUkBcKi0r1fi3x2vkopEq8zJl9snUHd3vUL06bLJANPE3DHFn9bbVSstK04rPV6jvGX01sd9EtW/WPuxYQK1AKSBu7Nm3Rw8veViZeZlq1qCZpl87XQM7D+QyUyCGKAXEhZyPc5SRnaGPdn6kwV0GK7NPppof3zzsWECtQykgVF/8+wsNf324pv1jms486Uzl/DxHl7S/JOxYQK1FKSAU7q4Za2bozgV3atfeXXrgogf0q56/UsP6DcOOBtRqlAJibtPOTbr11Vu14KMFuvCUCzX5ysk69+Rzw44FQJQCYmh/2X49tfwpjVo8SnXr1NWzlz+rW8+/VXXr1A07GoAISgExkb81X2lZaXr3n+/qquSr9Nzlz6ntiW3DjgXgAFGb5sLMXjSz7Wa2ttKyk8zsDTP7MPKzWbTWj/hQVFKk4QuG64LfXaCCogK9NOAl/fW//kohAHEqmnMfTZV02QHL7peU4+5nScqJPEcN9dqHr6nzhM56cvmTSuuapvXD1uu6TtfxvQMgjkWtFNx9qaQvD1h8taRpkcfTJF0TrfUjPNuKtmnQ7EG6YsYVOr7+8Vr238s0qf8kNW3QNOxoAA7D3D16H27WTlK2u3eOPN/l7k0rvb7T3Q96CMnM0iWlS1LLli27zZw5M2o5E0lRUZEaN24cdoyDcne99s/XNOnjSdpbulc3nnajBp42UEl1kqKyvngei1hjLCowFhV69eqV7+7nV+U9cXui2d0nS5osScnJyZ6amhpuoDiRm5ureByLD7/4UBnZGVq8ebEuOu0iTe4/WR1bdozqOuN1LMLAWFRgLI5NrEthm5m1dvcCM2staXuM149qVlJaosy3MvXwkofVoF4DPd//ef2i6y9Ux7hVB5CIYl0KcyXdLGlM5OecGK8f1Wj5Z8uVlpWmtdvXakCnAXr6sqfVuknrsGMBOAZRKwUz+5OkVEktzOwzSQ+qvAxmmtn/lfSppAHRWj+i5+vir/VAzgOa8PYEnXLCKZp7w1xdmXxl2LEAVIOolYK7DzzES72jtU5E35z352jYvGHaWrhVt11wmx655BE1Oa5J2LEAVJO4PdGM+LK1cKtuf+12zV4/W+e0OkezfzZbF57KPZKBmoZSwPcq8zJNzp+s+xbep5LSEj3e+3ENTxmu+nXrf+d38/Lygis/UlJSQkgL4FhRCjikdf9ap/SsdP1ty990SftL9Hz/53XmSWce9Hfz8vLUu3dvlZSUKCkpSTk5ORQDkIC4bhDfUby/WA8uflBdJnXR+h3rNfXqqVp408JDFoJUfm14SUmJSktLVVJSotzc3NgFBlBt2FPAtyz9ZKnSs9K14YsNuvGcG/Vk3yfVqlGrw74vNTVVSUlJwZ4CXx4CEhOlAEnSrr279D9v/I+mrJyidk3baf6N89X3zL5H/P6UlBTl5ORwTgFIcJRCLefuemndS7p9/u3avnu77km5R6NTR6tRUqMqf1ZKSgplACQ4SqEW2/LVFg2dN1TZH2Sra+uuenXQq+raumvYsQCEiFKohUrLSjX+7fEauWikyrxMYy8dq9svvF316rA5ALUd/wrUMqu3rVZaVppWfL5Cl515mSb2m6h2TduFHQtAnKAUaok9+/bo4SUPKzMvU80aNNOMa2fohs43cBc0AN9CKdQCCz9eqFuyb9FHOz/SkC5D9MSlT+ikhieFHQtAHKIUarixb43VPW/cozNPOlOLfr5Ivdr3CjsSgDhGKdRw/c/ur517d2rkxSPVsH7DsOMAiHOUQg2X3CJZj1zySNgxACQI5j4CAAQoBQBAgFIAAAQohRouLy9Pjz/+uPLy8sKOAiABcKK5BuPGNwCqij2FGowb3wCoKkqhBvvmxjd169blxjcAjgiHj2owbnwDoKoohRqOG98AqAoOHwEAApQCACBAKQAAApQCACBAKQAAApQCACBAKQAAApQCACBAKQAAApQCACAQyjQXZrZZUqGkUkn73f38MHIAAL4tzLmPern7jhDXDwA4AIePAAABc/fYr9Rsk6SdklzS8+4++SC/ky4pXZJatmzZbebMmbENGaeKiorUuHHjsGPEBcaiAmNRgbGo0KtXr/yqHp4PqxTauPtWM2sl6Q1Jt7n70kP9fnJysm/YsCF2AePYN/dGAGNRGWNRgbGoYGZVLoVQDh+5+9bIz+2SXpF0QRg5AADfFvNSMLNGZtbkm8eSLpW0NtY5AADfFcbVRydLesXMvln/DHefH0IOAMABYl4K7v6xpP8T6/UCAA6PS1IBAAFKAQAQoBQAAAFKAQAQoBQAAAFKAQAQoBQAAAFKAQAQoBQAAAFKAQAQoBQAAAFKAQAQoBQAAAFKAQAQoBQAAAFKAQAQoBQAAAFKAQAQoBQAAAFKAQAQoBQAAAFKAQAQoBQAAAFKAQAQoBQAAAFKAQAQoBQAAAFKAQAQoBQAAAFKAQAQoBQAAAFKAQAQoBQAAAFKAQAQoBQAAIFQSsHMLjOzDWa20czuDyMDAOC7Yl4KZlZX0nhJl0vqJGmgmXWKdQ4AwHeFsadwgaSN7v6xu5dI+rOkq0PIAQA4QL0Q1nmKpC2Vnn8m6cIDf8nM0iWlR54Wm9naGGRLBC0k7Qg7RJxgLCowFhUYiwrJVX1DGKVgB1nm31ngPlnSZEkys3fc/fxoB0sEjEUFxqICY1GBsahgZu9U9T1hHD76TFLbSs9PlbQ1hBwAgAOEUQpvSzrLzNqbWZKkGyTNDSEHAOAAMT985O77zeyXkhZIqivpRXd/7zBvmxz9ZAmDsajAWFRgLCowFhWqPBbm/p3D+QCAWopvNAMAApQCACAQ16XAdBjfZmabzWyNma06mkvNEpmZvWhm2yt/X8XMTjKzN8zsw8jPZmFmjJVDjMVoM/s8sm2sMrMrwswYC2bW1swWm9l6M3vPzO6ILK9128X3jEWVt4u4PacQmQ7jA0l9VH4Z69uSBrr7ulCDhcjMNks6391r3RdzzKynpCJJf3D3zpFlv5X0pbuPifxPQzN3vy/MnLFwiLEYLanI3TPDzBZLZtZaUmt3X2lmTSTlS7pG0mDVsu3ie8biZ6ridhHPewpMh4GAuy+V9OUBi6+WNC3yeJrK/xLUeIcYi1rH3QvcfWXkcaGk9SqfMaHWbRffMxZVFs+lcLDpMI7qP7IGcUmvm1l+ZBqQ2u5kdy+Qyv9SSGoVcp6w/dLMVkcOL9X4QyaVmVk7SedJ+rtq+XZxwFhIVdwu4rkUjmg6jFqmh7t3VfkMs8MihxEASZoo6QxJXSQVSBobbpzYMbPGkmZLutPdvw47T5gOMhZV3i7iuRSYDuMA7r418nO7pFdUfoitNtsWOZb6zTHV7SHnCY27b3P3UncvkzRFtWTbMLP6Kv9HcLq7vxxZXCu3i4ONxdFsF/FcCkyHUYmZNYqcQJKZNZJ0qaTaPnPsXEk3Rx7fLGlOiFlC9c0/ghE/VS3YNszMJL0gab27P1nppVq3XRxqLI5mu4jbq48kKXL51FOqmA7j0ZAjhcbMfqjyvQOpfHqSGbVpPMzsT5JSVT4t8jZJD0r6q6SZkk6T9KmkAe5e40/AHmIsUlV+iMAlbZaU8c1x9ZrKzC6StEzSGkllkcUPqPxYeq3aLr5nLAaqittFXJcCACC24vnwEQAgxigFAECAUgAABCgFAECAUgAABCgFoIoiM1JuMrOTIs+bRZ6fHnY24FhRCkAVufsWlU8fMCayaIykye7+SXipgOrB9xSAoxCZUiBf0ouS0iSdF5nNF0ho9cIOACQid99nZvdKmi/pUgoBNQWHj4Cjd7nKZ57sHHYQoLpQCsBRMLMuKr8rYHdJdx0w8RiQsCgFoIoiM1JOVPmc9Z9KekJSrbkNJmo2SgGoujRJn7r7G5HnEyR1MLOfhJgJqBZcfQQACLCnAAAIUAoAgAClAAAIUAoAgAClAAAIUAoAgAClAAAI/H9wxrB8koUvBAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x2db32199b38>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def initPlot():\n",
    "    plt.figure()\n",
    "    plt.xlabel('X')\n",
    "    plt.ylabel('Y')\n",
    "    plt.axis([0, 25, 0, 25])\n",
    "    plt.grid(True)\n",
    "    return plt\n",
    "\n",
    "plt = initPlot()\n",
    "plt.plot(x, y, 'k.')\n",
    "plt.plot(x,  w0 + w1 * x, 'g-')     # 绘制最优拟合曲线\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### **案例2：使用批量梯度下降算法拟合多维数据**\n",
    "* 待拟合的数据点：\n",
    " * 样本点对应的x值：[[6, 2], [8, 1], [10, 0], [14, 2], [18, 0]])\n",
    " * 样本点对应的y值：[19, 21, 23, 43, 47])\n",
    "* 上述数据点是根据函数$y=3*x_1+4*x_2-7$生成的"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">**步骤1：定义样本数据，并扩展一个全为1的列**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 6.  2.  1.]\n",
      " [ 8.  1.  1.]\n",
      " [10.  0.  1.]\n",
      " [14.  2.  1.]\n",
      " [18.  0.  1.]]\n"
     ]
    }
   ],
   "source": [
    "# 下列数据是按照y=3*x1+4*x2-7函数生成的\n",
    "x_origin = np.array([[6, 2], [8, 1], [10, 0], [14, 2], [18, 0]])\n",
    "y = np.array([19, 21, 23, 43, 47])\n",
    "x_ext = np.c_[x_origin, np.ones(len(x_origin))]         # 追加全是1的数据列\n",
    "print(x_ext)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">**步骤2：定义目标函数和梯度函数**\n",
    "* 在计算$\\nabla w$时，可以一次性计算所有$w$分量的梯度：$\\nabla w=X^T \\cdot (Xw-y)$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def target_function(W, X, Y):                           # 目标函数\n",
    "    temp = X.dot(W) - Y                                 # X.dot(W)可视为拟合曲线计算得出的结果向量；temp则是残差\n",
    "    return temp.dot(temp) / (2 * len(X))                # 实际上就是残差平方和除以2m\n",
    "\n",
    "def gradient_function(W, X, Y):                         # 梯度计算函数\n",
    "    length = len(X)\n",
    "    W_grad = (X.T).dot(X.dot(W) - Y) / length           # 仔细对照批量梯度下降法的梯度计算该公式理解此代码\n",
    "    return W_grad"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">**步骤3：定义多元梯度下降算法计算函数**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def batch_gradient_descent(target_fn, gradient_fn, init_W, X, Y, learning_rate=0.001, tolerance=1e-12):\n",
    "    \"\"\"支持多变量的批量梯度下降法\"\"\"\n",
    "    # 假设函数为：y = wn * xn + w(n-1) * x(n-1) +... + w2 * x2 + w1 * x1 + w0 * x0 其中，x0为1\n",
    "    # X中：第一列为xn,第二列为x(n-1)，依次类推，最后一列为x0(全为1)\n",
    "    # W向量顺序是：wn,w(n-1),...w1,w0，要确保与X中各列顺序一致\n",
    "    W = init_W\n",
    "    target_value = target_fn(W, X, Y) \n",
    "    iter_count = 0\n",
    "    while iter_count < 50000:                      # 如果50000次循环仍未收敛，则认为无法收敛\n",
    "        gradient = gradient_fn(W, X, Y)\n",
    "        next_W = W - gradient * learning_rate \n",
    "        next_target_value = target_fn(next_W, X, Y)\n",
    "        if abs(target_value - next_target_value) < tolerance:\n",
    "            print(\"循环\", iter_count, \"次后收敛\")\n",
    "            return W\n",
    "        else:                                             \n",
    "            W, target_value = next_W, next_target_value\n",
    "            iter_count += 1\n",
    "    \n",
    "    print(\"50000次循环后，计算仍未收敛\")\n",
    "    return W"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">**步骤4：执行批量梯度下降操作，并检验拟合数据与样本数据的差异**\n",
    "* 尝试调整tolerance和learning_rate的值，查看循环所需的次数。一般应有下列情况：\n",
    "  * 给定learning_rate=0.01，调整tolerance(1e-5, 1e-6, 1e-7,1e-8等)，可看到，拟合精度越高，需要的循环次数越多\n",
    "  * 给定tolerance=1e-7，调整learning_rate(0.1, 0.01, 0.005, 0.001等)，可看到，如果rate太大，有可能直接溢出造成不收敛；rate太小，就需要更多的循环次数，甚至在规定的循环次数内无法收敛\n",
    "* 采用相同的tolerance和learning_rate时，无论运行多少次，每次的结果都是相同的"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "循环 30562 次后收敛\n",
      "W值 [ 2.99998771  3.99995426 -6.99980291]\n",
      "拟合曲线计算的数据值与真实值的差异： [ 3.18885104e-05  5.30614618e-05  7.42344132e-05 -6.63961846e-05\n",
      " -2.40502817e-05]\n"
     ]
    }
   ],
   "source": [
    "np.random.seed(0)\n",
    "init_W = np.random.randn(x_ext.shape[1])                # x_ext现在是5x3的矩阵，因此init_W需要3个初始元素\n",
    "learning_rate = 0.005                                   # 可调整learning_rate\n",
    "tolerance = 1e-12                                        # 可调整tolerance\n",
    "W = batch_gradient_descent(target_function, gradient_function, init_W, x_ext, y, learning_rate, tolerance)\n",
    "print(\"W值\", W)                                         # 此结果应与3, 4, -7接近\n",
    "\n",
    "print(\"拟合曲线计算的数据值与真实值的差异：\", x_ext.dot(W) - y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### **随机梯度下降算法**\n",
    "计算梯度时，从训练数据中随机选择一组数据，仅针对该数据求导。假设选择的数据是$(x^{(k)}, y^{(k)})$，则：  \n",
    "$ \\dfrac{\\partial F(w)}{\\partial w_0}=2 \\cdot (w_0+ w_1 x^{(k)}-y^{(k)}) $  \n",
    "$ \\dfrac{\\partial F(w)}{\\partial w_1}=2 \\cdot (w_0+ w_1 x^{(k)}-y^{(k)}) \\cdot x^{(k)} $"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### **案例3：使用随机梯度下降法拟合多维数据**\n",
    ">下面的代码演示了如何使用随机梯度下降法来拟合多维数据，要注意：  \n",
    ">* learning_rate不能太大，否则无法收敛。\n",
    ">* 每次运行的结果可能略有差异"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "已完成循环次数： 50000\n",
      "[ 2.99950312  3.99816954 -6.99196235]\n",
      "拟合曲线计算的数据值与真实值的差异： [ 0.00139542  0.00223211  0.0030688  -0.00257965 -0.00090627]\n"
     ]
    }
   ],
   "source": [
    "''' 随机梯度下降算法进行多元线性拟合 '''\n",
    "\n",
    "def stochastic_gradient_descent(target_fn, gradient_fn, init_W, X, Y, learning_rate=0.001, tolerance=1e-12):\n",
    "    \"\"\"随机梯度下降法，本函数内部尝试使用了动态调整learning_rate\"\"\"\n",
    "    W = init_W\n",
    "    rate = learning_rate\n",
    "    min_W, min_target_value = None, float(\"inf\") \n",
    "    iter_count = 0\n",
    "    iterations_with_no_improvement = 0\n",
    "    m = len(X)\n",
    "    target_value = target_fn(W, X, Y)\n",
    "    while iter_count < 50000 and iterations_with_no_improvement < 100:  # 如果连续缩小学习速率100次进行计算，都没有计算出更低的结果值，则说明已经收敛\n",
    "        target_value = target_fn(W, X, Y)\n",
    "        iter_count += 1\n",
    "        if target_value < min_target_value:             # 计算出了更低的成本值\n",
    "            min_W, min_target_value = W, target_value\n",
    "            iterations_with_no_improvement = 0          # 因为找到了新的低值，重新开始计数\n",
    "            rate = learning_rate                        # rate也恢复到初始值\n",
    "        else:                                           # 未能计算出更低成本值，此时缩小学习速率再尝试，直到连续缩小100次\n",
    "            iterations_with_no_improvement += 1\n",
    "            rate *= 0.9\n",
    "        index = np.random.randint(0, m)                     # 获得一组随机数据的索引值\n",
    "        gradient = gradient_fn(W, X[index], Y[index])       # 计算该数据点处的导数\n",
    "        W = W - learning_rate * gradient\n",
    "    print(\"已完成循环次数：\", iter_count)\n",
    "    return min_W\n",
    "\n",
    "def target_function(W, X, Y):\n",
    "    temp = X.dot(W) - Y \n",
    "    return temp.dot(temp) / (2 * len(X))\n",
    "\n",
    "# 只需要计算一个样本点(xi)的梯度\n",
    "def gradient_function(W, xi, yi):\n",
    "    return 2 * xi * (xi.dot(W) - yi)\n",
    "\n",
    "# 下列数据是按照y=3*x1+4*x2-7函数生成的\n",
    "x_origin = np.array([[6, 2], [8, 1], [10, 0], [14, 2], [18, 0]])\n",
    "y = np.array([19, 21, 23, 43, 47])\n",
    "x_ext = np.c_[x_origin, np.ones(len(x_origin))]         # 追加全是1的数据列\n",
    "\n",
    "import time\n",
    "np.random.seed((int)(time.time()))                      # 每次运行程序产生不同的随机序列\n",
    "init_W = np.random.randn(x_ext.shape[1])\n",
    "# 执行随机梯度下降操作\n",
    "W = stochastic_gradient_descent(target_function, gradient_function, init_W, x_ext, y)      \n",
    "print(W)\n",
    "\n",
    "print(\"拟合曲线计算的数据值与真实值的差异：\", x_ext.dot(W) - y)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
